import numpy as np
import matplotlib.pyplot as plt


# Experiment dimensions. The plotting code expects each saved result array to
# hold n_exp * n_trees runs of n_simulations steps each.
n_exp = 5
n_trees = 5
n_simulations = 1000

# Tree parameters (k, d) to plot, one subplot per pair; k[i] is paired with d[i].
k = [16, 200, 14, 16, 16, 200]
d = [1, 1, 3, 3, 4, 2]
if len(k) != len(d):
    # zip(k, d) would otherwise silently drop the unmatched tail.
    raise ValueError('k and d must have the same length (got %d and %d)' % (len(k), len(d)))

# Exploration coefficient and temperature shared by most algorithms; together
# they select the results directory (see folder_name below).
exploration_coeff = 1.
tau = .1

# Algorithm identifiers (as they appear in result file names) and the matching
# legend labels. Note that 'dents' results are labelled "BTS".
algs = ['uct', 'power-uct', 'dng', 'fixed-depth-mcts', 'ments', 'rents', 'tents', 'dents', 'catso', 'patso']
algs_legend = ["UCT", "Power-UCT", "DNG", "Fixed-Depth-MCTS", "MENTS", "RENTS", "TENTS", "BTS", "CATSO", "PATSO"]

# One color per algorithm (matplotlib's default "tab10" color cycle).
algs_legend_color = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
                     '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
if not len(algs) == len(algs_legend) == len(algs_legend_color):
    # The plotting loop indexes all three lists with the same position.
    raise ValueError('algs, algs_legend and algs_legend_color must have equal lengths')

# NOTE: not referenced by the plotting code; per-algorithm alphas are chosen
# there instead (1 for most algorithms, 10 for CATSO/PATSO).
alphas = [8]
default_atom = 10  # Number of atoms; embedded in every result file name.

# Results directory for the shared exploration_coeff / tau setting.
folder_name = './logs/expl_%.2f_tau_%.2f' % (exploration_coeff, tau)

def _results_folder(expl_coeff, temperature):
    """Return the results directory for an (exploration coefficient, tau) setting."""
    return './logs/expl_%.2f_tau_%.2f' % (expl_coeff, temperature)


def _run_settings(alg):
    """Return ``(alpha, atoms, folder)`` identifying where *alg*'s results are stored.

    * alpha is 10 for CATSO/PATSO and 1 for every other algorithm.
    * UCT and Fixed-Depth-MCTS are read from the exploration-0.05 folder (global
      tau); DENTS (labelled "BTS") from the exploration-0.75 / tau-0.5 folder;
      every other algorithm from the shared ``folder_name``.
    """
    alpha = 10 if alg in {'catso', 'patso'} else 1
    if alg in {'uct', 'fixed-depth-mcts'}:
        folder = _results_folder(0.05, tau)
    elif alg == 'dents':
        folder = _results_folder(0.75, 0.5)
    else:
        folder = folder_name
    return alpha, default_atom, folder


# MENTS / RENTS / TENTS are drawn dotted; everything else uses the default style.
dotted_algs = {'ments', 'rents', 'tents'}

# First plotted line of each algorithm, used as its handle in the figure legend.
legend_handles = {}

# Create figure
plt.figure(figsize=(15, 5))

# One subplot per (k, d) combination
for count_plot, (kk, dd) in enumerate(zip(k, d)):
    max_diff_uct = 0

    plt.subplot(1, len(k), count_plot + 1)
    plt.title('k=%d  d=%d' % (kk, dd), fontsize='medium')
    if count_plot == 0:
        plt.ylabel('Value Estimation Error', fontsize='medium')
    plt.xlabel('# Simulations', fontsize='small')

    for alg_idx, alg in enumerate(algs):
        alpha, atoms, folder_name_alg = _run_settings(alg)
        subfolder_name = folder_name_alg + '/k_%d_d_%d' % (kk, dd)
        color = algs_legend_color[alg_idx]

        # Only the load is expected to fail (missing run); keep the try minimal.
        try:
            diff_uct = np.load(subfolder_name + '/diff_uct_%s_%f_%d.npy' % (alg, alpha, atoms))
        except FileNotFoundError:
            print(f"Warning: Could not find data for {alg} with k={kk}, d={dd}, alpha={alpha}, atoms={atoms}")
            continue

        # Collapse all leading axes into a single "runs" axis (the last axis is
        # the simulation step) so the mean and the band use the same runs.
        runs = diff_uct.reshape(-1, diff_uct.shape[-1])
        n_runs, n_steps = runs.shape
        steps = np.arange(n_steps)
        avg_diff_uct = runs.mean(axis=0)
        # Band of +/- 2 standard errors of the mean across runs.
        err = 2 * runs.std(axis=0) / np.sqrt(n_runs)

        style = {'linestyle': 'dotted'} if alg in dotted_algs else {}
        line, = plt.plot(steps, avg_diff_uct, linewidth=1.5, color=color,
                         label=algs_legend[alg_idx], **style)
        legend_handles.setdefault(alg, line)
        plt.fill_between(steps, avg_diff_uct - err, avg_diff_uct + err,
                         alpha=.1, color=color)

        max_diff_uct = max(max_diff_uct, avg_diff_uct.max())

    # Set grid and limits
    plt.grid(True, alpha=0.3)
    if max_diff_uct > 0:
        plt.ylim(0, max_diff_uct * 1.1)  # 10% headroom above the highest mean curve
    else:
        # No data in this panel: ylim(0, 0) would be a singular range.
        plt.ylim(bottom=0)

    # x-axis ticks at start, middle and end of the simulation budget
    ticks = [0, n_simulations // 2, n_simulations]
    plt.xticks(ticks, [str(t) for t in ticks], fontsize='small')
    plt.xlim(0, n_simulations)

# Adjust layout
plt.tight_layout()

# Legend below the plots, built from explicit handles in `algs` order. Giving
# figlegend only a list of labels makes matplotlib pair them positionally with
# every artist in the figure (including the fill_between bands), which
# mislabels entries and shifts them whenever a data file is missing.
handles = [legend_handles[alg] for alg in algs if alg in legend_handles]
labels = [label for alg, label in zip(algs, algs_legend) if alg in legend_handles]
if handles:
    plt.figlegend(handles, labels, fontsize='small', loc='lower center',
                  bbox_to_anchor=(0.5, -0.15), ncol=5, frameon=False)

# Adjust subplot spacing to make room for legend
plt.subplots_adjust(bottom=0.15)

# Save the figure
plt.savefig("results.pdf", bbox_inches='tight', pad_inches=0.1)
plt.show()